#!/usr/bin/env python
# coding: utf-8
# based on https://github.com/andersbll/nnet
# modified by heibanke
"""
mlp: use a neural network (multi-layer perceptron) to recognise handwritten digits
"""
import shelve

import numpy as np
from matplotlib import pyplot as plt

from nnet.layers import Linear, Activation, MSECostLayer
from nnet.neuralnetwork import neuralnetwork

#import sklearn.datasets

#digits = sklearn.datasets.load_digits()

# write to file (one-off: uncomment this and the sklearn lines above to create sk_digits.dat)
#s = shelve.open('sk_digits.dat') 
#s['digits'] = digits
#s.close()

def show_pic(sample, pic_size=8):
    """
    Show the given samples as square grayscale pictures stacked vertically.

    sample: 2-D array-like with one flattened picture per row; every row
        must hold pic_size*pic_size values.
    pic_size: side length of each square picture (default 8, the value
        that used to be hard-coded).

    The figure is displayed with plt.show(); nothing is returned.
    """
    data = np.asarray(sample, dtype=float)
    num_sample = data.shape[0]

    # Reshaping the whole array at once is equivalent to reshaping each row
    # to (pic_size, pic_size) and stacking the results top to bottom, and it
    # avoids the Python-2-only xrange loop.
    img_data = data.reshape(pic_size * num_sample, pic_size)

    # Values are scaled by 255 before display.
    # NOTE(review): presumably sample is normalised to [0, 1] -- confirm with callers.
    plt.imshow(np.floor(img_data * 255), cmap="gray")
    plt.show()


# Only samples whose label is below NUM_RANGE are kept.
NUM_RANGE = 2

# read from file
S = shelve.open('sk_digits.dat')
try:
    DIGITS = S['digits']
finally:
    # Close the shelf even when the 'digits' entry cannot be read.
    S.close()

X_train = DIGITS.data
# Scale all features by the global maximum. The division is in place, so
# when DIGITS.data is a NumPy array it is rescaled as well.
X_train /= np.max(X_train)
y_train = DIGITS.target

# Select the rows whose label is in the requested range.
idx, = np.where(y_train < NUM_RANGE)
y_train = y_train[idx]
X_train = X_train[idx, :]
n_classes = np.unique(y_train).size

# Build the multi-layer perceptron: Linear(n_out=32) -> relu ->
# Linear(n_out=n_classes) -> tanh, with MSECostLayer as the cost.
nn = neuralnetwork(
    layers=[
        Linear(n_out=32, weight_scale=0.1),
        Activation('relu'),
        Linear(n_out=n_classes, weight_scale=0.1),
        Activation('tanh'),
    ],
    cost=MSECostLayer(),
)

# The first two thirds of the samples are used for training. Floor division
# keeps TRAIN_NUM an integer (usable as a slice index) on Python 2 and 3.
TRAIN_NUM = len(y_train) * 2 // 3

# Train neural network
print('Training neural network')
nn.train(X_train[:TRAIN_NUM], y_train[:TRAIN_NUM], learning_rate=0.01, max_iter=20, batch_size=32)

# Evaluate on the remaining samples, i.e. those not passed to nn.train
error = nn.error(X_train[TRAIN_NUM:], y_train[TRAIN_NUM:])
# NOTE(review): the label says "Training" although the error is measured on
# the held-out samples; the text is kept unchanged to preserve the output.
print('Training error rate: %.4f' % error)


